# Wipe user-defined names from the module namespace before running.
# `dir()` is evaluated once, so deleting entries while looping is safe.
# Names that begin with a double underscore (interpreter internals such as
# __name__ and __builtins__) are left in place.
for element in dir():
    if element.startswith("__"):
        continue
    del globals()[element]

import os
import pickle
import argparse
from time import time

import numpy as np
import tensorflow as tf

from functions_preprocessing import english_standardization, visualize_prediction
from functions_postprocessing import remove_serial_duplicates, remove_specific_token
from models import build_model
from matplotlib import pyplot as plt


# =================================================================
# =================================================================
# KEYWORD SPOTTING USING CNN-RNN-CTC
# TEST PREDICTION AND EVALUATE PERFORMANCE
# =================================================================
# =================================================================


# Usage example: python kws_predict.py -r results_model02_01

# --------------------------------------------
#  C O M M A N D   L I N E
# --------------------------------------------

parser = argparse.ArgumentParser()
# Alternative defaults previously used for the bundled examples:
#   --datafolder examples/data   --resultfolder examples/result_model02_noise
parser.add_argument(
    "-d", "--datafolder", default='data_CTC_pre', type=str,
    help="Path to data folder - generated by `prepare_dataset.py`.")
parser.add_argument(
    "-r", "--resultfolder", default='result11', type=str,
    help="Path to result folder.")
parser.add_argument(
    "-n", "--num_samples", default=4, type=int,
    help="Number of samples to compute prediction.")
parser.add_argument(
    "-o", "--offset", default=0, type=int,
    help="Offset number of samples to skip in the dataset.")
# Default is None (not 0.0) so that "flag not given" can be told apart from an
# explicit 0.0; when omitted, the rate stored with the trained model is used.
parser.add_argument(
    "--noise_aug", default=None, type=float,
    help="Noise augmentation rate (from 0 to 1). "
         "Defaults to the rate stored with the trained model.")
args = parser.parse_args()

if args.noise_aug is not None and not 0.0 <= args.noise_aug <= 1.0:
    parser.error("--noise_aug must be between 0 and 1, got {!r}".format(args.noise_aug))

# History of manually selected models (kept for reference only; the model is
# now chosen through --resultfolder and the parameters pickled next to it):
#resultfolder, model_name, model_num = 'results_model01_01', 'model-48-2.101.h5', 1
#resultfolder, model_name, model_num, feature_type = 'results_model02_01', 'model-068-1.842.h5', 2, 'mfcc'
#resultfolder, model_name, model_num, feature_type = 'results_model02_noise', 'model-099-1.596.h5', 2, 'mfcc'
#resultfolder, model_name, model_num, feature_type = 'results_model02_02', 'model-068-2.431.h5', 2, 'mfcc'
#resultfolder, model_name, model_num, feature_type = 'results_model03_01', 'model-064-2.599.h5', 3, 'mfcc'
#resultfolder, model_name, model_num, feature_type = 'results_model04_01', 'model-065-2.028.h5', 4, 'mfcc'
#resultfolder, model_name, model_num, feature_type = 'results_model05_01', 'model-069-2.554.h5', 5, 'melspec'
#resultfolder, model_name, model_num, feature_type = 'results_model06_01', 'model-069-2.847.h5', 6, 'melspec'

#datafolder = 'data'
#subset = 3


# --------------------------------------------
#  L O A D   P A R A M E T E R S
# --------------------------------------------

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only point --datafolder / --resultfolder at folders you trust.
with open(os.path.join(args.datafolder, "parameters.pickle"), 'rb') as f:
    dataset_params = pickle.load(f)
with open(os.path.join(args.resultfolder, "parameters.pickle"), 'rb') as f:
    train_params = pickle.load(f)
with open(os.path.join(args.datafolder, "noise_aug.pickle"), 'rb') as f:
    noise_settings = pickle.load(f)

keywords = dataset_params['keywords']
num_kwd = dataset_params['num_kwd']

# dataset partitioning and shuffling
model_num = train_params['model_num']
feature_type = train_params['feature_type']
train_test_split = train_params['train_test_split']  # not used in the visible part of this script
rng_seed = train_params['rng_seed']

# Noise rate: an explicit --noise_aug wins, otherwise reuse the training value.
if args.noise_aug is None:
    noise_aug = train_params['noise_aug']
else:
    noise_aug = args.noise_aug

# STFT SPECIFICATION
frame_length = train_params['frame_length']
frame_step = train_params['frame_step']
fft_length = train_params['fft_length']
# MFCC SPECIFICATION
lower_freq = train_params['lower_freq']
upper_freq = train_params['upper_freq']
n_mel_bins = train_params['n_mel_bins']
n_mfcc_bins = train_params['n_mfcc_bins']

# Feature dimension passed to build_model depends on the feature type.
# Fail here with a clear message instead of a NameError further down.
if feature_type == 'mfcc':
    feat_dim = n_mfcc_bins
elif feature_type == 'melspec':
    feat_dim = n_mel_bins
else:
    raise ValueError(
        "Unsupported feature_type {!r}; expected 'mfcc' or 'melspec'.".format(feature_type))

# Text vectorizer with a fixed vocabulary (the keyword list) and the project's
# own standardization function.
text_processor = tf.keras.layers.experimental.preprocessing.TextVectorization(
    standardize=english_standardization,
    max_tokens=None,
    vocabulary=keywords)


# =====================================================
print('\n(1) LOAD DATASET AND MODEL')
# =====================================================

# (A) Select the (input path, target text) pairs to run prediction on.
#
# The hand-picked recordings below take precedence over the pickled datasets
# stored in `args.datafolder`. Empty both lists to fall back to the datasets.
# Previously the dataset pickles (including the LibriSpeech one, which was
# never used) were always loaded and then immediately overwritten.
override_paths = ["test/shanghai.wav","test/beijing.wav","test/shengteng.wav","test/beijing_1.wav","test/0182-昇腾-中国-城市-北京.wav"]
override_texts = ["上海","北京","昇腾","北京","昇腾 中国 城市 北京"]

if override_paths:
    input_paths = list(override_paths)
    target_texts = list(override_texts)
else:
    # Speech Commands (original + edited). The LibriSpeech dictionary
    # ("librispeech_dict.pickle") was already excluded from the concatenation.
    # NOTE(security): pickle.load can execute arbitrary code; trusted files only.
    # Presumably each pickle holds list-valued 'input_path' / 'target_text'
    # entries - confirm against prepare_dataset.py.
    input_paths, target_texts = [], []
    for dict_name in ("speech_commands_dict.pickle", "speech_commands_edit_dict.pickle"):
        with open(os.path.join(args.datafolder, dict_name), 'rb') as f:
            subset = pickle.load(f)
        input_paths.extend(subset['input_path'])
        target_texts.extend(subset['target_text'])

if len(input_paths) != len(target_texts):
    raise ValueError(
        "Mismatched sample lists: {} input paths vs {} target texts.".format(
            len(input_paths), len(target_texts)))

#print(input_paths)
#print(target_texts)

# (B) Shuffle (based on previously recorded seed), applying the same
# permutation to paths and texts so the pairs stay aligned.
num_samples = len(input_paths)
print(num_samples)
np.random.seed(rng_seed)
p = np.random.permutation(num_samples)
print("**********",p)
input_paths = [ input_paths[i] for i in p ]
target_texts = [ target_texts[i] for i in p ]

# (C) Build the prediction dataset
#
# Only the window [offset, offset + num_samples) of the shuffled sample list
# is handed to `visualize_prediction`.

window = slice(args.offset, args.offset + args.num_samples)

ds = visualize_prediction(
    input_paths[window],
    target_texts[window],
    noise_settings=noise_settings,
    noise_prob=noise_aug,
    feature_type=feature_type,
    vectorizer=text_processor,
    sr=dataset_params['sampling_rate'],
    frame_len=frame_length,
    frame_hop=frame_step,
    fft_len=fft_length,
    num_mel_bins=n_mel_bins,
    lower_freq=lower_freq,
    upper_freq=upper_freq,
    num_mfcc=n_mfcc_bins,
)

# (D) Rebuild the network and restore the trained weights
#
# `build_model` returns two objects; only the second one is used here.

weights_file = os.path.join(args.resultfolder, 'model_weights.h5')
_, model_pred = build_model(model_num, feat_dim, num_kwd)
model_pred.load_weights(weights_file)


# =====================================================
print('\n(4) PREDICT AND EVALUATE PERFORMANCE')
# =====================================================

# Debug aid (disabled): print the shapes of every batch.
# for i, (audio_batch, feats_batch, text_batch) in enumerate(ds):
#     print(audio_batch.shape)
#     print(feats_batch.shape)
#     print(text_batch.shape)

# Index of the CTC null token; it is stripped from predictions and targets.
null_token = num_kwd + 1

for i, (audio_batch, feats_batch, text_batch) in enumerate(ds):
    # (a) Per-frame token probabilities,
    #     tensor of shape ( N, seq_len, NUM_KWD+2 )
    token_prob = model_pred(feats_batch)

    # (b) Most probable token at each time instant,
    #     tensor of shape ( N, seq_len )
    tokens_pred = tf.argmax(token_prob, axis=-1)

    # (c) Post-processing 1: collapse consecutive duplicates
    deduplicated = remove_serial_duplicates(tokens_pred)

    # (d) Post-processing 2: drop the CTC null token from both the
    #     prediction and the reference
    tokens_post_p1 = remove_specific_token(deduplicated, null_token)
    tokens_true_p1 = remove_specific_token(text_batch.numpy(), null_token)

    print(tokens_post_p1)
    print(tokens_true_p1)
    print("-------------------------------")

    # (e) Post-processing 3 (disabled): remove the [unk] token
    #     (precision and recall require this)
    # tokens_post_p2 = remove_specific_token(tokens_post_p1, 0)
    # tokens_true_p2 = remove_specific_token(tokens_true_p1, 0)

    # Plotting of the first sample of each batch (disabled). It used to be
    # switched off by wrapping it in a string literal; kept here as comments.
    #
    # plt.figure(i+1)
    #
    # plt.subplot(2,1,1); plt.grid()
    # for j in range(0,num_kwd+1):
    #     plt.plot(token_prob[0,:,j])
    # plt.plot(token_prob[0,:,num_kwd+1],'--',color='black',linewidth=1) # lime, springgreen, khaki
    # plt.ylabel('Token probability'); plt.xlabel('Output samples')
    # plt.title( 'True tokens: ' + str(tokens_true_p1[0]) + '\nPredicted: ' + str(tokens_post_p1[0]) )
    # plt.legend(['0 (filler)','1 (shanghai)','2 (beijing)'], bbox_to_anchor=(1.0,1.2))
    #
    # plt.subplot(2,1,2); plt.grid()
    # plt.plot(audio_batch[0])
    # plt.ylabel('Speech signal'); plt.xlabel('Audio samples')
    # plt.subplots_adjust(bottom=0.1, top=0.9, left=0.11, right=0.8, hspace=0.3)

# Disabled, after the loop:
# #print("*****",target_texts[tokens_post_p1[0]])
# plt.show()
# ==================================================================